package com.jstarcraft.ai.jsat.classifiers.linear;

import java.util.Arrays;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.SimpleWeightVectorModel;
import com.jstarcraft.ai.jsat.classifiers.BaseUpdateableClassifier;
import com.jstarcraft.ai.jsat.classifiers.CategoricalData;
import com.jstarcraft.ai.jsat.classifiers.CategoricalResults;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.distributions.Distribution;
import com.jstarcraft.ai.jsat.distributions.LogUniform;
import com.jstarcraft.ai.jsat.exceptions.FailedToFitException;
import com.jstarcraft.ai.jsat.linear.DenseVector;
import com.jstarcraft.ai.jsat.linear.IndexValue;
import com.jstarcraft.ai.jsat.linear.ScaledVector;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.lossfunctions.HingeLoss;
import com.jstarcraft.ai.jsat.lossfunctions.LossC;
import com.jstarcraft.ai.jsat.lossfunctions.LossFunc;
import com.jstarcraft.ai.jsat.lossfunctions.LossMC;
import com.jstarcraft.ai.jsat.lossfunctions.LossR;
import com.jstarcraft.ai.jsat.math.decayrates.DecayRate;
import com.jstarcraft.ai.jsat.math.decayrates.PowerDecay;
import com.jstarcraft.ai.jsat.math.optimization.stochastic.GradientUpdater;
import com.jstarcraft.ai.jsat.math.optimization.stochastic.SimpleSGD;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.regression.BaseUpdateableRegressor;
import com.jstarcraft.ai.jsat.regression.RegressionDataSet;
import com.jstarcraft.ai.jsat.regression.UpdateableRegressor;

/**
 * LinearSGD learns either a classification or regression problem depending on
 * the {@link #setLoss(com.jstarcraft.ai.jsat.lossfunctions.LossFunc) loss
 * function &#x2113;(w,x)} used. The solution attempts to minimize
 * <big>&sum;</big><sub>i</sub> &#x2113;(w,x<sub>i</sub>) +
 * {@link #setLambda0(double) &lambda;<sub>0</sub>}/2
 * ||w||<sub>2</sub><sup>2</sup> + {@link #setLambda1(double)
 * &lambda;<sub>1</sub>} ||w||<sub>1</sub>, and is trained by Stochastic
 * Gradient Descent. <br>
 * <br>
 * NOTE: To support L<sub>1</sub> regularization with sparse results and online
 * learning at the same time, the normalization of the regularization penalty by
 * the number of data points is not present in the implementation at this time.
 * Setting {@link #setLambda1(double) &lambda;<sub>1</sub>} to the desired value
 * divided by the number of unique data points in the whole set will result in
 * the correct regularization penalty being applied. <br>
 * <br>
 * Weighted data points are not supported ({@link #supportsWeightedData()}
 * returns {@code false}); the {@code weight} argument of the update methods is
 * ignored.
 * 
 * See:
 * <ul>
 * <li>Tsuruoka, Y., Tsujii, J., &amp; Ananiadou, S. (2009). <i>Stochastic
 * gradient descent training for L1-regularized log-linear models with
 * cumulative penalty</i>. Proceedings of the Joint Conference of the 47th
 * Annual Meeting of the ACL and the 4th International Joint Conference on
 * Natural Language Processing of the AFNLP, 1, 477.
 * doi:10.3115/1687878.1687946</li>
 * </ul>
 * 
 * @author Edward Raff
 */
public class LinearSGD extends BaseUpdateableClassifier implements UpdateableRegressor, Parameterized, SimpleWeightVectorModel {

    private static final long serialVersionUID = -59695592724956535L;
    /** loss function; its type decides regression / binary / multi-class support */
    private LossFunc loss;
    /** template updater; cloned once per weight vector in setUpShared */
    private GradientUpdater gradientUpdater;
    /** initial learning rate */
    private double eta;
    /** schedule used to decay eta over time */
    private DecayRate decay;
    /** one weight vector per output (1 for binary/regression, K for K-class) */
    private Vec[] ws;
    /** per-weight-vector gradient updaters */
    private GradientUpdater[] gus;
    /** per-weight-vector bias terms */
    private double[] bs;
    /** number of updates performed since setUp */
    private int time;
    private double lambda0;
    private double lambda1;
    /** total L1 penalty each weight could have received so far (u in Tsuruoka et al.) */
    private double l1U;
    /** L1 penalty actually applied to each weight so far (q in Tsuruoka et al.) */
    private double[][] l1Q;
    private boolean useBias = true;

    /**
     * Creates a new LinearSGD learner for multi-class classification problems.
     */
    public LinearSGD() {
        this(new HingeLoss(), 1e-4, 0.0);
    }

    /**
     * Creates a new LinearSGD learner
     * 
     * @param loss    the loss function to use
     * @param lambda0 the L<sub>2</sub> regularization term
     * @param lambda1 the L<sub>1</sub> regularization term
     */
    public LinearSGD(LossFunc loss, double lambda0, double lambda1) {
        this(loss, 0.001, new PowerDecay(1, 0.1), lambda0, lambda1);
    }

    /**
     * Creates a new LinearSGD learner.
     * 
     * @param loss    the loss function to use
     * @param eta     the initial learning rate
     * @param decay   the decay rate for &eta;
     * @param lambda0 the L<sub>2</sub> regularization term
     * @param lambda1 the L<sub>1</sub> regularization term
     */
    public LinearSGD(LossFunc loss, double eta, DecayRate decay, double lambda0, double lambda1) {
        setLoss(loss);
        setEta(eta);
        setEtaDecay(decay);
        setGradientUpdater(new SimpleSGD());
        setLambda0(lambda0);
        setLambda1(lambda1);
    }

    /**
     * Copy constructor. All mutable state is deep-copied so the copy and the
     * original can be trained independently.
     * 
     * @param toClone the object to copy
     */
    public LinearSGD(LinearSGD toClone) {
        this.loss = toClone.loss.clone();
        this.eta = toClone.eta;
        this.decay = toClone.decay.clone();
        this.time = toClone.time;
        this.lambda0 = toClone.lambda0;
        this.lambda1 = toClone.lambda1;
        this.l1U = toClone.l1U;
        this.useBias = toClone.useBias;
        // clone rather than share, consistent with loss/decay/gus above
        this.gradientUpdater = toClone.gradientUpdater.clone();
        if (toClone.l1Q != null) {
            this.l1Q = new double[toClone.l1Q.length][];
            for (int i = 0; i < toClone.l1Q.length; i++)
                this.l1Q[i] = Arrays.copyOf(toClone.l1Q[i], toClone.l1Q[i].length);
        }
        if (toClone.ws != null) {
            this.ws = new Vec[toClone.ws.length];
            this.bs = new double[toClone.bs.length];
            this.gus = new GradientUpdater[toClone.gus.length];
            for (int i = 0; i < ws.length; i++) {
                this.ws[i] = toClone.ws[i].clone();
                this.bs[i] = toClone.bs[i];
                this.gus[i] = toClone.gus[i].clone();
            }
        }
    }

    /**
     * Sets the method that will be used to update the weight vectors given their
     * gradient information.
     * 
     * @param gradientUpdater the method to use for updating the weight vectors from
     *                        the gradient
     * @throws IllegalArgumentException if {@code gradientUpdater} is null
     */
    public void setGradientUpdater(GradientUpdater gradientUpdater) {
        if (gradientUpdater == null)
            throw new IllegalArgumentException("Gradient updater must be non-null");
        this.gradientUpdater = gradientUpdater;
    }

    /**
     * 
     * @return the method to use for updating the weight vectors from the gradient
     */
    public GradientUpdater getGradientUpdater() {
        return gradientUpdater;
    }

    /**
     * Sets the rate at which {@link #setEta(double) &eta;} is decayed at each
     * update.
     * 
     * @param decay the decay rate to use
     * @throws IllegalArgumentException if {@code decay} is null
     */
    public void setEtaDecay(DecayRate decay) {
        if (decay == null)
            throw new IllegalArgumentException("Decay rate must be non-null");
        this.decay = decay;
    }

    /**
     * Returns the decay rate in use
     * 
     * @return the decay rate in use
     */
    public DecayRate getEtaDecay() {
        return decay;
    }

    /**
     * Sets the initial learning rate &eta; to use. It should generally be in (0,
     * 1), but any positive value is acceptable.
     * 
     * @param eta the learning rate to use.
     * @throws IllegalArgumentException if {@code eta} is not a finite positive
     *                                  value
     */
    public void setEta(double eta) {
        if (eta <= 0 || Double.isNaN(eta) || Double.isInfinite(eta))
            throw new IllegalArgumentException("eta must be a positive constant, not " + eta);
        this.eta = eta;
    }

    /**
     * Returns the current learning rate in use
     * 
     * @return the current learning rate in use
     */
    public double getEta() {
        return eta;
    }

    /**
     * Sets the loss function used for the model. The loss function controls whether
     * or not regression, binary classification, or multi-class classification is
     * supported.
     * 
     * @param loss the loss function to use
     * @throws IllegalArgumentException if {@code loss} is null
     */
    public void setLoss(LossFunc loss) {
        if (loss == null)
            throw new IllegalArgumentException("Loss function must be non-null");
        this.loss = loss;
    }

    /**
     * Returns the loss function in use
     * 
     * @return the loss function in use
     */
    public LossFunc getLoss() {
        return loss;
    }

    /**
     * &lambda;<sub>0</sub> controls the L<sub>2</sub> regularization penalty.
     * 
     * @param lambda0 the L<sub>2</sub> regularization penalty to use
     * @throws IllegalArgumentException if {@code lambda0} is negative or not
     *                                  finite
     */
    public void setLambda0(double lambda0) {
        if (lambda0 < 0 || Double.isNaN(lambda0) || Double.isInfinite(lambda0))
            throw new IllegalArgumentException("Lambda0 must be non-negative, not " + lambda0);
        this.lambda0 = lambda0;
    }

    /**
     * Returns the L<sub>2</sub> regularization term in use
     * 
     * @return the L<sub>2</sub> regularization term in use
     */
    public double getLambda0() {
        return lambda0;
    }

    /**
     * &lambda;<sub>1</sub> controls the L<sub>1</sub> regularization penalty. It
     * may be changed between updates; the cumulative-penalty bookkeeping is
     * allocated on demand.
     * 
     * @param lambda1 the L<sub>1</sub> regularization penalty to use
     * @throws IllegalArgumentException if {@code lambda1} is negative or not
     *                                  finite
     */
    public void setLambda1(double lambda1) {
        if (lambda1 < 0 || Double.isNaN(lambda1) || Double.isInfinite(lambda1))
            throw new IllegalArgumentException("Lambda1 must be non-negative, not " + lambda1);
        this.lambda1 = lambda1;
    }

    /**
     * Returns the L<sub>1</sub> regularization term in use
     * 
     * @return the L<sub>1</sub> regularization term in use
     */
    public double getLambda1() {
        return lambda1;
    }

    /**
     * Sets whether or not an implicit bias term will be added to the data set
     * 
     * @param useBias {@code true} to add an implicit bias term
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /**
     * Returns whether or not an implicit bias term is in use
     * 
     * @return {@code true} if a bias term is in use
     */
    public boolean isUseBias() {
        return useBias;
    }

    @Override
    public LinearSGD clone() {
        return new LinearSGD(this);
    }

    /**
     * Prepares the model for classification. Two categories use a single weight
     * vector (requires a {@link LossC}); more categories use one weight vector per
     * category (requires a {@link LossMC}).
     *
     * @throws FailedToFitException if the loss cannot handle the problem type, or
     *                              there are no numeric attributes
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (!(loss instanceof LossC))
            throw new FailedToFitException("Loss function " + loss.getClass().getSimpleName() + " only supports regression");
        if (predicting.getNumOfCategories() == 2) {
            ws = new Vec[1];
            bs = new double[1];
            gus = new GradientUpdater[1];
        } else {
            if (!(loss instanceof LossMC))
                throw new FailedToFitException("Loss function " + loss.getClass().getSimpleName() + " only supports binary classification");
            ws = new Vec[predicting.getNumOfCategories()];
            bs = new double[predicting.getNumOfCategories()];
            gus = new GradientUpdater[predicting.getNumOfCategories()];
        }
        setUpShared(numericAttributes);
    }

    /**
     * Prepares the model for regression using a single weight vector.
     *
     * @throws FailedToFitException if the loss is not a {@link LossR}, or there are
     *                              no numeric attributes
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes) {
        if (!(loss instanceof LossR))
            throw new FailedToFitException("Loss function " + loss.getClass().getSimpleName() + " does not support regression");

        ws = new Vec[1];
        bs = new double[1];
        gus = new GradientUpdater[1];
        setUpShared(numericAttributes);
    }

    /**
     * Allocates the weight vectors and per-vector updaters (the arrays themselves
     * must already be sized by the caller) and resets the time step and L1
     * bookkeeping.
     */
    private void setUpShared(int numericAttributes) {
        if (numericAttributes <= 0)
            throw new FailedToFitException("LinearSGD requires numeric features to use");
        for (int i = 0; i < ws.length; i++) {
            ws[i] = new ScaledVector(new DenseVector(numericAttributes));
            gus[i] = gradientUpdater.clone();
            gus[i].setup(ws[i].length());
        }
        time = 0;
        l1U = 0;
        if (lambda1 > 0)
            l1Q = new double[ws.length][ws[0].length()];
        else
            l1Q = null;
    }

    /**
     * Performs one SGD step on a classification example. The {@code weight}
     * argument is ignored.
     */
    @Override
    public void update(DataPoint dataPoint, double weight, int targetClass) {
        final double eta_t = decay.rate(time++, eta);
        final Vec x = dataPoint.getNumericalValues();

        applyL2Reg(eta_t);

        // apply gradient updates
        if (ws.length == 1) {
            // binary case: class index {0, 1} becomes label {-1, +1}
            final double y = targetClass * 2 - 1;
            final double lossD = ((LossC) loss).getDeriv(ws[0].dot(x) + bs[0], y);
            performGradientUpdate(0, eta_t, lossD, x);
        } else {
            final Vec pred = multiClassScores(x);
            ((LossMC) loss).deriv(pred, pred, targetClass);
            for (IndexValue iv : pred)
                performGradientUpdate(iv.getIndex(), eta_t, iv.getValue(), x);
        }

        applyL1Reg(eta_t, x);
    }

    /**
     * Computes w<sub>i</sub>&middot;x + b<sub>i</sub> for every weight vector and
     * passes the result through the multi-class loss's {@code process} step.
     *
     * @param x the input vector
     * @return a new vector holding one processed score per weight vector
     */
    private Vec multiClassScores(Vec x) {
        final Vec pred = new DenseVector(ws.length);
        for (int i = 0; i < ws.length; i++)
            pred.set(i, ws[i].dot(x) + bs[i]);
        ((LossMC) loss).process(pred, pred);
        return pred;
    }

    /**
     * Applies one gradient step to weight vector {@code i} (and its bias, if in
     * use).
     *
     * @param i     the index of the weight vector array to update
     * @param eta_t the learning rate to use
     * @param lossD the derivative of the loss with respect to the prediction for
     *              the specified weight vector
     * @param x     the input vector the loss was incurred on
     */
    private void performGradientUpdate(final int i, final double eta_t, final double lossD, Vec x) {
        final Vec grad = new ScaledVector(lossD, x);
        if (useBias)
            bs[i] -= gus[i].update(ws[i], grad, eta_t, bs[i], lossD);
        else
            gus[i].update(ws[i], grad, eta_t);
    }

    /**
     * Performs one SGD step on a regression example. The {@code weight} argument
     * is ignored.
     */
    @Override
    public void update(DataPoint dataPoint, double weight, double targetValue) {
        final double eta_t = decay.rate(time++, eta);
        final Vec x = dataPoint.getNumericalValues();

        applyL2Reg(eta_t);

        final double lossD = ((LossR) loss).getDeriv(ws[0].dot(x) + bs[0], targetValue);

        performGradientUpdate(0, eta_t, lossD, x);

        applyL1Reg(eta_t, x);
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        final Vec x = data.getNumericalValues();
        if (ws.length == 1)
            return ((LossC) loss).getClassification(ws[0].dot(x) + bs[0]);
        return ((LossMC) loss).getClassification(multiClassScores(x));
    }

    @Override
    public double regress(DataPoint data) {
        final Vec x = data.getNumericalValues();
        return ((LossR) loss).getRegression(ws[0].dot(x) + bs[0]);
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Applies L2 regularization to the model by shrinking every weight vector by
     * a factor of (1 - &eta;<sub>t</sub>&lambda;<sub>0</sub>).
     * 
     * @param eta_t the learning rate in use
     */
    private void applyL2Reg(final double eta_t) {
        if (lambda0 > 0)// apply L2 regularization
            for (Vec v : ws)
                v.mutableMultiply(1 - eta_t * lambda0);
    }

    /**
     * Applies L1 regularization to the model using the cumulative penalty method
     * of Tsuruoka et al. Only the weights at the indices visited when iterating
     * {@code x} are clipped on this step.
     * 
     * @param eta_t the learning rate in use
     * @param x     the input vector the update is from
     */
    private void applyL1Reg(final double eta_t, Vec x) {
        if (lambda1 <= 0)
            return;
        // lambda1 may have been enabled after setUp(), in which case the
        // per-weight penalty bookkeeping was never allocated
        if (l1Q == null || l1Q.length != ws.length)
            l1Q = new double[ws.length][ws[0].length()];

        l1U += eta_t * lambda1;// line 6: in Tsuruoka et al paper, figure 2
        for (int k = 0; k < ws.length; k++) {
            final Vec w_k = ws[k];
            final double[] l1Q_k = l1Q[k];
            for (IndexValue iv : x) {
                final int i = iv.getIndex();
                // see "APPLYPENALTY(i)" on line 15: from Figure 2 in Tsuruoka et al paper
                final double z = w_k.get(i);
                double newW_i = 0;
                if (z > 0)
                    newW_i = Math.max(0, z - (l1U + l1Q_k[i]));
                else if (z < 0)
                    newW_i = Math.min(0, z + (l1U - l1Q_k[i]));
                l1Q_k[i] += (newW_i - z);
                w_k.set(i, newW_i);
            }
        }
    }

    /** Trains sequentially; the {@code parallel} flag is ignored. */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    @Override
    public void train(RegressionDataSet dataSet) {
        BaseUpdateableRegressor.trainEpochs(dataSet, this, getEpochs());
    }

    @Override
    public Vec getRawWeight(int index) {
        return ws[index];
    }

    @Override
    public double getBias(int index) {
        return bs[index];
    }

    @Override
    public int numWeightsVecs() {
        return ws.length;
    }

    /**
     * Guess the distribution to use for the regularization term
     * {@link #setLambda0(double) &lambda;<sub>0</sub>}. The data set is currently
     * not consulted; a fixed range is returned.
     *
     * @param d the data set to get the guess for
     * @return the guess for the &lambda;<sub>0</sub> parameter
     */
    public static Distribution guessLambda0(DataSet d) {
        return new LogUniform(1e-7, 1e-2);
    }

    /**
     * Guess the distribution to use for the regularization term
     * {@link #setLambda1(double) &lambda;<sub>1</sub>}. The range is scaled by the
     * number of data points, matching the normalization note on this class.
     *
     * @param d the data set to get the guess for
     * @return the guess for the &lambda;<sub>1</sub> parameter
     */
    public static Distribution guessLambda1(DataSet d) {
        // an empty data set would otherwise divide by zero and yield an infinite range
        final int N = Math.max(1, d.size());
        return new LogUniform(1e-7 / N, 1e-3 / N);
    }

}
